Add DistributedEmbedding example for TPU on TensorFlow. #2174
Conversation
Summary of Changes
Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request adds a comprehensive example demonstrating distributed embedding training on TensorFlow using TPU SparseCore. It provides a practical guide for leveraging keras_rs.layers.DistributedEmbedding for large-scale recommendation systems, complementing the existing JAX-based example.
Highlights
- New TensorFlow DistributedEmbedding Example: Introduces a new example (distributed_embedding_tf.py) demonstrating the use of keras_rs.layers.DistributedEmbedding for movie ranking on TensorFlow with TPU SparseCore (a minimal usage sketch follows this list).
- Jupyter Notebook and Markdown Versions: Accompanying Jupyter Notebook (.ipynb) and Markdown (.md) versions of the example are added for easier consumption and documentation.
- Example Integration: The new TensorFlow example is integrated into the rs_master.py script, making it discoverable alongside existing examples.
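For readers new to the layer, here is a minimal, hedged sketch of the kind of configuration such an example sets up, assuming the `TableConfig`/`FeatureConfig` interface used by the existing JAX DistributedEmbedding example. The names, vocabulary size, and batch size below are placeholders, not values taken from this PR:

```python
import keras_rs

# Placeholder table: one embedding table for movie ids.
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=10_000,   # placeholder vocabulary size
    embedding_dim=32,         # placeholder embedding dimension
)

# Placeholder feature that looks up ids in the table above.
feature_configs = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=movie_table,
        input_shape=(256,),       # per-step batch of movie ids
        output_shape=(256, 32),   # one embedding vector per id
    ),
}

# On TensorFlow/TPU, the layer is created under the TPUStrategy scope
# discussed later in this conversation.
embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)
```

Calling the layer with a structure of id tensors matching `feature_configs` returns the same structure of dense embedding vectors.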
Code Review
This pull request adds a new example demonstrating the use of keras_rs.layers.DistributedEmbedding with TensorFlow on TPUs. The example is well-structured and provides a clear walkthrough of setting up the TPU strategy, preparing the dataset, configuring the distributed embedding layer, and training a ranking model. I've found a couple of potential issues related to the configuration of FeatureConfig and the usage of the embedding layer's output, which could lead to runtime errors. My detailed comments are below.
Force-pushed from 802bb88 to 4dca226.
abheesht17 left a comment:
Thanks for the PR! It's great news that this works on TPUs :)
```shell
pip install -U -q tensorflow-tpu==2.19.1
pip install -q keras-rs
```
I think you'd mentioned this before: should we add an optional extra (the square-bracket syntax) to the KerasRS setup files, like so: `pip install -q keras-rs[tpu]`? Or `pip install -q keras-rs[dist-emb-tpu]`, or something?
That's a good question, we should experiment with this. Part of the complication is the combinatorics of backends and hardware. One other issue I have faced is that some packages clash over the versions they want for their dependencies (protobuf or keras, for instance).
I think we would have to do:
- `keras-rs[tf-tpu]` (adds `tensorflow-tpu==2.19.1`)
- `keras-rs[jax-tpu]` (adds `jax-tpu-embedding` and `jax[tpu]`)
```python
SparseCore chips of all the available TPUs.
"""

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
```
[knowledge] I'm assuming this won't work in the multi-host case, right? Since tpu = "local"
Correct. Is it just a question of passing the name of the TPU cluster? If so, I can add a variable and a comment explaining how to do it. But I haven't tested it.
Yeah, I think you can pass the GCP zone, project, etc. too. I just asked this question for knowledge. We don't have to add it since we haven't tested it yet. What do you think?
I can't get my hands on a multi-host v6e right now. We can submit as-is and revisit.
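For reference, a hedged and untested sketch of what the multi-host variant discussed above might look like; the TPU name, zone, and project are placeholders, not values from this PR:

```python
import tensorflow as tf

# Placeholders only; this variant was not tested in the discussion above.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu="my-tpu-name",          # name of the Cloud TPU resource, instead of "local"
    zone="us-central1-a",       # GCP zone, optional if it can be inferred
    project="my-gcp-project",   # GCP project, optional if it can be inferred
)
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
```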
```python
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
```
Curious why we have to do this instead of using MirroredStrategy? Maybe, we can add a note here for the reader as to why this is necessary? What do you think?
Well, that's what the text at lines 64-69 is for: if you're on TPU, you have to use a `TPUStrategy`. From https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy:

> This strategy is typically used for training on one machine with multiple GPUs. For TPUs, use tf.distribute.TPUStrategy.

I could add a link to the TensorFlow documentation.
> I could add a link to the TensorFlow documentation.

Yeah, let's add it?
Done
Force-pushed from 4dca226 to 2196a37.
abheesht17 left a comment:
LGTM, replied to some comments
Force-pushed from 2196a37 to 93145b5.
This was run on a cloud TPU v6e-1.
Also tweaked some comments in the JAX DistributedEmbedding example.